%% Estimation. Load data

clear
% workdir = '/Applications/Dropbox/RnD_KK/tradecost_replication';
workdir = '/Dropbox/RnD_KK/tradecost_replication';

datadir = strcat(workdir,'/data_matlab');
% Put the project code under ReStat_code on the MATLAB path. (Never assign
% to a variable named "addpath": that would shadow the function.)
addpath(genpath(strcat(workdir,'/ReStat_code')))

thresh = '40';               % # entrants in a destination must be >= thresh

prm.s = 6;
data.year = '2004';

% Long-format firm data: rows are firms, grouped by product (the number of
% rows per product is data.numfirms(:,2)); columns 3:end are the K
% destinations. M holds quantities, M2 holds prices.
M = csvread(strcat(datadir,'/data_qty_',thresh,'thresh_',data.year,'.raw'));
M2 = csvread(strcat(datadir,'/data_price_',thresh,'thresh_',data.year,'.raw'));

% Load number of firms/obs per product
% numfirms has one row per product: [product_id number_of_firms sigma]
data.numfirms = csvread(strcat(datadir,'/data_',data.year,'_count_',thresh,'thresh.raw'));

% Load country characteristics (country code, real GDP, GDP/capita,
% population, distance). Not used in the estimation, only later on.
dgfile = strcat(datadir,'/data_dist_gdp_',thresh,'thresh_',data.year,'.raw');
[fid,msg] = fopen(dgfile);
if fid < 0
  error('Cannot open %s: %s', dgfile, msg);
end
N = textscan(fid, '%s %f %f %f %f','delimiter',',');
fclose(fid);
data.countries = N{1};
data.gdp = N{2};
data.cgdp = N{3};
% N{4} (population) is read but not stored
data.dist = N{5};

data.P = size(data.numfirms,1);  % # products
data.K = size(M,2)-2;            % # countries (destination columns of M)
K = data.K; params.K = data.K;

% Data is originally in "long" format, so products appear sequentially.
% Rearrange to "wide" format: one entry per (destination n, product p).
% PS: We need to make sure that all products in the CSV file are actually
% used in estimation. I.e. that no products have less than $thresh firms in
% all destinations.
thresh2 = str2double(thresh);

% Preallocate every destination x product output at full K x P size.
% Without this the arrays only grow up to the largest n that is ever used.
% If the last destination(s) are never usable, they end up with fewer than
% K rows. That breaks later code that loops n = 1:data.K, combines them
% with K x P estimates, or masks them with the K x P matrix data.usedata.
KP = [data.K data.P];
data.usedata      = zeros(KP);
data.price_mean   = zeros(KP);
data.price_median = zeros(KP);
data.var_lnX      = zeros(KP);
data.var_lnprice  = zeros(KP);
data.cov_px       = zeros(KP);
data.corr_px      = zeros(KP);
data.entrants     = zeros(KP);
data.emom         = zeros([KP 3]);
phi    = zeros(KP);
phisig = zeros(KP);
cellfields = {'sales_actual','lnX','lnX_dev','price', ...
  'priceinv','priceinv_dev','priceinv2','priceinv2_dev', ...
  'priceinv3','priceinv3_dev','priceinv4','priceinv4_dev', ...
  'lnprice','lnprice2','lnprice3','lnprice4', ...
  'lnprice_dev','lnprice2_dev','lnprice3_dev','lnprice4_dev', ...
  'priceinvlnp','priceinvlnp_dev','lnValue','ECDF','esales'};
for fi = 1:numel(cellfields)
  data.(cellfields{fi}) = cell(KP);
end

start = 1;
for p=1:data.P
  stop = start+data.numfirms(p,2)-1;
  qty = M(start:stop,3:end);     % I_p x K sales (quantity), I_p = # firm rows of product p
  price = M2(start:stop,3:end);  % I_p x K prices
  entrants = qty>0;              % I_p x K export participation

  for n=1:data.K
    tmp = qty(:,n);
    tmp_p = price(:,n);
    if nnz(entrants(:,n))>=thresh2           % Check if we want to use this product-dest pair
      data.usedata(n,p)=1;
      data.sales_actual{n,p} = tmp(entrants(:,n));
      data.lnX{n,p} = log(data.sales_actual{n,p});
      data.lnX_dev{n,p} = data.lnX{n,p} - mean(data.lnX{n,p});

      data.price{n,p} = tmp_p(entrants(:,n));

      % Powers of inverse price, raw and demeaned
      data.priceinv{n,p} = 1./data.price{n,p};
      data.priceinv_dev{n,p} = data.priceinv{n,p} - mean(data.priceinv{n,p});

      data.priceinv2{n,p} = data.price{n,p}.^(-2);
      data.priceinv2_dev{n,p} = data.priceinv2{n,p} - mean(data.priceinv2{n,p});

      data.priceinv3{n,p} = data.price{n,p}.^(-3);
      data.priceinv3_dev{n,p} = data.priceinv3{n,p} - mean(data.priceinv3{n,p});

      data.priceinv4{n,p} = data.price{n,p}.^(-4);
      data.priceinv4_dev{n,p} = data.priceinv4{n,p} - mean(data.priceinv4{n,p});

      % Powers of log price, raw and demeaned
      data.lnprice{n,p} = log(data.price{n,p});
      data.lnprice2{n,p} = data.lnprice{n,p}.^2;
      data.lnprice3{n,p} = data.lnprice{n,p}.^3;
      data.lnprice4{n,p} = data.lnprice{n,p}.^4;

      data.lnprice_dev{n,p} = data.lnprice{n,p} - mean(data.lnprice{n,p});
      data.lnprice2_dev{n,p} = data.lnprice2{n,p} - mean(data.lnprice2{n,p});
      data.lnprice3_dev{n,p} = data.lnprice3{n,p} - mean(data.lnprice3{n,p});
      data.lnprice4_dev{n,p} = data.lnprice4{n,p} - mean(data.lnprice4{n,p});

      data.priceinvlnp{n,p} = data.priceinv{n,p}.*data.lnprice{n,p};
      data.priceinvlnp_dev{n,p} = data.priceinvlnp{n,p} - mean(data.priceinvlnp{n,p});

      data.lnValue{n,p} = data.lnX{n,p} + data.lnprice{n,p};
      data.price_mean(n,p) = mean(data.price{n,p});
      data.price_median(n,p) = median(data.price{n,p});

      % Calculate phi: slope of ln X on ln price, shifted by prm.s
      C = cov(data.lnX{n,p},data.lnprice{n,p});
      data.var_lnX(n,p) = C(1,1);
      data.var_lnprice(n,p) = C(2,2);
      data.cov_px(n,p) = C(1,2);
      data.corr_px(n,p) = corr(data.lnX{n,p},data.lnprice{n,p});

      phi(n,p) = C(1,2)./C(2,2) + prm.s;
      phisig(n,p) = C(1,2)./C(2,2);

      [data.ECDF{n,p},data.esales{n,p}] = ecdf(data.sales_actual{n,p});
      data.entrants(n,p) = length(data.sales_actual{n,p});

      % Construct pct ratios of log sales (5th..95th percentiles)
      s = data.lnX{n,p};
      pcts = prctile(s,[5:5:95]);
      % Nudge tied low percentiles so the ratio denominators are nonzero
      if (pcts(1)==pcts(2)) pcts(2)=pcts(2)+.1; end
      if (pcts(2)==pcts(3)) pcts(3)=pcts(3)+.1; end

      data.emom(n,p,1) = (pcts(end)-pcts(end-1)) / (pcts(2)-pcts(1));
      data.emom(n,p,2) = (pcts(end-1)-pcts(end-2)) / (pcts(3)-pcts(2));
      data.emom(n,p,3) = var(s);

    else
      data.usedata(n,p)=0;
    end
  end
  start = stop+1;
end

% phi should only be product specific: average phi(n,p) over the
% destinations actually used for product p. Cells of phi for unused
% product-destination pairs are 0 placeholders; a plain mean(phi) would
% include them and pull the average towards zero (and would collapse to a
% scalar if phi were a single row).
usedphi = logical(data.usedata(1:size(phi,1),1:size(phi,2)));
data.phi = sum(phi.*usedphi,1)./sum(usedphi,1);   % 1 x P
data.phisig = phisig;        % N x P, 0 where the pair is not used

% Number of destinations per product
data.K2 = sum(data.usedata)' 
% Average # entrants over used pairs. Trim the mask to the size of
% data.entrants so that the linear logical indexing lines up even if
% entrants has fewer rows than usedata.
usedent = logical(data.usedata(1:size(data.entrants,1),1:size(data.entrants,2)));
data.avgentrants = mean(data.entrants(usedent));

% Load weights (export value per product-destination), normalised to sum to 1
data.weights = csvread(strcat(datadir,'/weights_',thresh,'thresh_',data.year,'.raw'));           % KxP
data.weights = data.weights/sum(data.weights(:));

% Load value & weight per product (one row per product); ratio = column 2 / column 1
tmp = csvread(strcat(datadir,'/weight_value_',thresh,'thresh_',data.year,'.raw'));
data.weight_value = tmp(:,2)./tmp(:,1);

% Load value & weight per product - only SE; ratio = column 2 / column 1
tmp = csvread(strcat(datadir,'/weight_value_SE_',thresh,'thresh_',data.year,'.raw'));
data.weight_value_SE = tmp(:,2)./tmp(:,1);

% Weight per quantity; ratio = column 1 / column 2
tmp = csvread(strcat(datadir,'/weight_qty_',thresh,'thresh_',data.year,'.raw'));
data.weight_qty = tmp(:,1)./tmp(:,2);


%%% ......................%%%
%%% TABLE 2: ROW 5 and 6  %%%
%%% ......................%%%
% The raw long-format inputs are no longer needed; free them
clear M M2 N;
disp('Year'); disp(data.year);
disp('Number of products'); disp(data.P);
disp('Number of countries'); disp(data.K);

% Number of usable product-destinations (usedata is a 0/1 matrix)
data.num_proddest = nnz(data.usedata);
disp('Number of product-destinations'); disp(data.num_proddest);


%disp('Number of products usable in estimation per destination');
%disp(sum(data.usedata'));

% Every product needs at least one usable destination to be estimated
if any(sum(data.usedata,1)==0)
  error('Some products have no usable destination');
end
disp('Data loaded');


%% NLS, estimate separately for each p
% For each product p, minimise imo_2011_NLSFE over the destinations used for
% p. Parameter layout: b = [ln t_n (one per used destination),
% a_n (one per used destination), phisig1].
res = [];
K = data.K;
P = data.P;
options = optimset('Display','iter','TolFun',1e-8,'TolX',1e-8);

tic
for p=1:data.P
  prm.p = p
  disp('Number of destinations'); disp(data.K2(p));
  f = @(b) imo_2011_NLSFE(b,data,prm,0);

  % Median prices from destinations with positive sales; they scale the
  % starting values and bounds for ln t.
  tmp = data.price_median(:,p);
  tmp = tmp(tmp>0);
  nn = length(tmp);
  % The split of b below assumes one median price per used destination
  if nn ~= data.K2(p)
    error('Product %d: %d destinations with positive median price but %d used destinations', ...
          p, nn, data.K2(p));
  end
  b0 = [log(tmp'*.1) ones(1,nn)*5 -2];
  lb = [log(tmp'*.01) ones(1,nn)*-10 -20];
  ub = [log(tmp'*2) ones(1,nn)*30 20];   % t max 2x price

  [b,fval,exitflag,output,lambda,grad,hessian]  = fmincon(f,b0,[],[],[],[],lb,ub,[],options);
  if exitflag<0, error('No estimate'); end
  lnt = b(1:data.K2(p))
  a = b(data.K2(p)+1:2*data.K2(p));
  phisig1 = b(end);
  %phisig2 = b(end)
  % NOTE(review): fmincon's Hessian output may be an approximation,
  % depending on the algorithm -- treat these standard errors with care.
  sderr = sqrt(diag(inv(hessian)));
  sderr_lnt = sderr(1:data.K2(p));
  sderr_phisig1 = sderr(end);

  % Scatter the estimates back to destination indices 1..K; destinations
  % not used for product p get NaN.
  used = logical(data.usedata(:,p));
  lnt_arr = nan(K,1);        lnt_arr(used) = lnt;
  sderr_lnt_arr = nan(K,1);  sderr_lnt_arr(used) = sderr_lnt;
  a_arr = nan(K,1);          a_arr(used) = a;

  res.lntmat(:,p) = lnt_arr;
  res.amat(:,p) = a_arr;
  res.phisig1(p) = phisig1;
  res.sderr_lnt(:,p) = sderr_lnt_arr;
  res.sderr_phisig1(p) = sderr_phisig1;
end
toc

res.tstat_lnt = res.lntmat./res.sderr_lnt;

%save(strcat(workdir,'/matlab_model2/res_2004_40thresh_sep3'),'res');
%save(strcat(workdir,'/matlab_model2/res_2004_20thresh_sep3'),'res');

% save(strcat(workdir,'/data_matlab/res_2004_',thresh,'thresh_sep3'),'res');

%% Display results - estimation for every p

%%% ....................%%%
%%% TABLE 3: column 1   %%%
%%% ....................%%%

% Optionally reload saved product-by-product estimates instead of re-running:
% load(strcat(workdir,'/data_matlab/res_2004_',thresh,'thresh_sep3'));

% Cells without data hold 0 price statistics; mark them as missing so they
% drop out of the ratios below.
data.price_mean(data.price_mean==0) = NaN;
data.price_median(data.price_median==0) = NaN;

% Trade cost relative to the median price, restricted to used cells
used = data.usedata==1;
tc_rel = exp(res.lntmat)./data.price_median;
tc_used = tc_rel(used);

% Winsorize: keep values strictly inside the 5th-95th percentile band
band = prctile(tc_used,[5 95]);
tc_trim = tc_used(tc_used>band(1) & tc_used<band(2));

disp('Grand mean TC, winsorized');
disp(mean(tc_trim));

% Weight each cell's ratio, then trim the weighted terms before summing
disp('Weighted mean trade costs relative to median price, winsorized'); 
wtc = tc_used.*data.weights(used);
wband = prctile(wtc,[5 95]);
disp(sum(wtc(wtc>wband(1) & wtc<wband(2))));

disp('Median trade cost relative to median price');
disp(median(tc_trim));

disp('Standard deviation of trade costs/median price');
disp(std(tc_trim));


%% NLS Joint estimation

% see estimation_short_restat.m


%% Display results : Joint estimation 

%%% ....................%%%
%%% TABLE 2: ROW 1      %%%
%%% ....................%%%

% Load the joint-estimation results for the same year/threshold as the data
load(strcat(workdir,'/data_matlab/res_',data.year,'_',thresh,'thresh_Restat7'));
%load(strcat(workdir,'/data_matlab/res_2004_20thresh_Restat7'));  
%load(strcat(workdir,'/data_matlab/res_2004_40thresh_Restat7'));  % BASELINE

close all;
savefigures = 0;

% Cells without data hold 0 price statistics; treat them as missing
data.price_median(data.price_median==0) = NaN;
data.price_mean(data.price_mean==0) = NaN;

% Trade cost relative to the median price (K x P)
j = exp(res.lntmat)./data.price_median;

% Used cells only, winsorized strictly inside the 5th-95th percentiles
jj = j(data.usedata==1);
pct = prctile(jj,[5 95]);
kk = jj(jj>pct(1) & jj<pct(2));


%%% TC
disp('Grand mean TC, winsorized');
disp(mean(kk));

disp('Weighted mean trade costs relative to median price, winsorized'); 
ww = data.weights(data.usedata==1);
ss = jj.*ww;
pct2 = prctile(ss,[5 95]);
disp(sum(ss(ss>pct2(1) & ss<pct2(2))));

disp('Median trade cost relative to median price');
disp(median(kk));

disp('Standard deviation of trade costs/median price');
disp(std(kk));

ppct_kk = prctile(kk,[75 25]);
disp('75/25 percentile of TC'); disp(ppct_kk(1)/ppct_kk(2))

disp('Trade costs US versus NL')
idxUS = find(strcmp(data.countries,'US'));
idxNL = find(strcmp(data.countries,'NL'));
if isempty(idxUS) || isempty(idxNL)
  warning('US or NL not found in data.countries; skipping comparison');
else
  disp(exp(res.lnt_n(idxUS)-res.lnt_n(idxNL)));
end

% Interquartile ratios of the product (t_p) and country (t_n) effects
ppct_p = prctile(exp(res.lnt_p),[75 25]);
ppct_n = prctile(exp(res.lnt_n),[75 25]);

disp('75/25 percentile of t_p'); disp(ppct_p(1)/ppct_p(2))
disp('75/25 percentile of t_n'); disp(ppct_n(1)/ppct_n(2))

%%% .................%%%
%%% FIGURE 4         %%%
%%% .................%%%
% Plot country FE vs distance
%

% Winsorize: keep country (lnt_n) and product (lnt_p) effects strictly
% between the low-th and high-th percentiles
low = 10; high=100-low;
ii = res.lnt_n;
sel_n = ii > prctile(ii,low) & ii < prctile(ii,high);
ii = res.lnt_p;
sel_p = ii > prctile(ii,low) & ii < prctile(ii,high);

figure(1);
idx = res.lnt_n(sel_n)';
xx = log(data.dist(sel_n));
cty = data.countries(sel_n);
scatter(xx,idx,'Marker','none')
for j = 1:length(idx), text(xx(j),idx(j),cty(j)), end; grid on;
xlabel('Distance, km (logs)'); 
ylabel('$\ln\widetilde{t}_{n}$','Interpreter','latex'); 
lsline;
if (savefigures) saveas(gcf,strcat(workdir,'/graphs/res_joint_distanceFE.eps')); end

% Regression of country FE against distance
y = res.lnt_n(sel_n)';
X = [ones(sum(sel_n),1) log(data.dist(sel_n)) log(data.gdp(sel_n)) log(data.cgdp(sel_n))];
[b,bint] = regress(y,X);
disp('Coeff (const, dist, gdp, gdp/capita) and 95% conf interval');
[b bint]
% regress builds the 95% interval as b +/- tinv(0.975,dof)*se, where dof is
% (#non-NaN rows - rank(X)). Invert it with the same dof instead of a
% hard-coded value, which would only be right for one sample size.
ok = ~any(isnan([y X]),2);
dof = nnz(ok) - rank(X(ok,:));
disp('Standard error distance'); disp((bint(2,2)-b(2))/tinv(0.975,dof));


%%% .................%%%
%%% FIGURE 2         %%%
%%% .................%%%
% Plot distribution of t/p in Sweden
figure(2);
idxSE = find(strcmp(data.countries,'SE'));
if isempty(idxSE), error('Country code SE not found in data.countries'); end
% t/p for Sweden across products (1 x P), winsorized at the 5th/95th percentiles
j = res.lntmat - log(data.price_median);
tcosts = exp(j(idxSE,:));
pctSE = prctile(tcosts,[5 95]);
tmp = tcosts(tcosts>pctSE(1) & tcosts<pctSE(2));

% Left panel: kernel density of t/p in Sweden
[f,xi] = ksdensity(tmp);
subplot(1,2,1); plot(xi,f);
xlabel('TC_{SE}'); 
ylabel('Density estimate');
axis([0 max(tmp) 0 max(f)]);

% Right panel: density of phisig (slope parameter), winsorized at 5/95
limit_l = prctile(res.phisig1_p(:),5);
limit_h = prctile(res.phisig1_p(:),95);
tmp = res.phisig1_p(res.phisig1_p>limit_l & res.phisig1_p<limit_h);
[f,xi] = ksdensity(tmp);
subplot(1,2,2); plot(xi,f); 
xlabel('\sigma_{k}'); 
%xlabel('$\widetilde{\phi }_{1k}$','Interpreter','latex'); 
ylabel('Density estimate');
axis([limit_l limit_h 0 max(f)]);
if (savefigures) saveas(gcf,strcat(workdir,'/graphs/res_joint_TC_and_phisig_density.eps')); end
if (savefigures) saveas(gcf,strcat(workdir,'/graphs/res_joint_TC_and_phisig_density.png')); end

%%% ....................%%%
%%% TABLE 2: ROW 2      %%%
%%% ....................%%%

% NOTE(review): sel_p is the 10/90 winsorization mask built from res.lnt_p
% (product trade-cost effects) in the Figure 4 section, not from
% phisig1_p itself -- confirm this is the intended trimming for phi1.
disp('Mean phi1'); disp(mean(res.phisig1_p(sel_p)));
disp('Median phi1'); disp(median(res.phisig1_p(sel_p)));
disp('Stdev phi1'); disp(std(res.phisig1_p(sel_p)));

disp('Weighted mean phi1'); 
% Weights: Swedish export values across products, normalised to sum to 1
ww = data.weights(idxSE,:);
ww = ww/sum(ww);
% Force both factors to columns so the product is elementwise even if
% phisig1_p is stored as a column (a row .* column would expand to P x P).
ss = res.phisig1_p(:).*ww(:);
pct3 = prctile(ss,[5 95]);
disp(sum(ss(ss>pct3(1) & ss<pct3(2))));



%%% .................%%%
%%% FIGURE 5         %%%
%%% .................%%%
% Scatter products FEs and weight/value
%

% Select the products plotted: SE trade costs strictly inside the 5/95
% percentile band (NaN entries fail both comparisons). Labels, x and y
% must all use this same mask so each label sits on its own point.
inband = tcosts>pctSE(1) & tcosts<pctSE(2);
hs8name = cellstr(int2str(data.numfirms(:,1)));
hs8name = hs8name(inband);
yy = log(tcosts(inband));
xx = log(data.weight_value_SE);
xx = xx(inband);

figure(3);
scatter(xx,yy,'Marker','none'); grid on;
for j = 1:length(xx), text(xx(j),yy(j),hs8name(j),'FontSize',6), end
xlabel('Weight/value (logs)'); ylabel('TC (logs)'); 
lsline;
if (savefigures) saveas(gcf,strcat(workdir,'/graphs/res_joint_productFE.eps')); end

% Use R (not "corr") so the corr function is not shadowed in the workspace
disp('Corr - pval weight/value - tcosts (logs)'); 
[R, pval] = corrcoef(xx',yy,'rows','pairwise');
disp([R(1,2) pval(1,2)]);


%
% Which products have unusually high or low trade costs relative to price
% in Sweden? Column 1: product code (data.numfirms(:,1)), column 2: SE t/p.
%
prod_tc = [data.numfirms(:,1) tcosts'];

% Products at or above the 95th percentile of t/p
hi_cut = prctile(tcosts,95);
top = prod_tc(prod_tc(:,2)>=hi_cut,:);
disp('Products with > 95 percentile tc/price'); disp(top(:,1));
disp('Value'); disp(top(:,2));

% Products strictly below the 5th percentile of t/p
lo_cut = prctile(tcosts,5);
bottom = prod_tc(prod_tc(:,2)<lo_cut,:);
disp('Products with < 5 percentile tc/price'); disp(bottom(:,1));
disp('Value'); disp(bottom(:,2));


%%% ................%%%
%%% TABLE 2: ROW 4  %%%
%%% ................%%%

% Model fit. 
%

% R squared. Pool the demeaned log quantities of all used
% product-destination pairs; unused cells are empty and contribute nothing.
% vertcat over the cell list avoids growing allx inside a double loop and
% does not depend on lnX_dev having exactly K rows.
allx = vertcat(data.lnX_dev{:});

SStot = sum((allx-mean(allx)).^2);

% Index ranges of the parameter groups in res.b, presumably used by
% imo_2011_NLS_all3FE to split it (t_n: K, t_p: P-1, phisig1_p: P,
% a_p: P-1, a_n: K entries)
prm.idx_t_n = 1:data.K;
prm.idx_t_p = data.K+1:data.K+data.P-1;
prm.idx_phisig1_p = data.K+data.P:data.K+2*data.P-1;
prm.idx_a_p = data.K+2*data.P:data.K+3*data.P-2;
prm.idx_a_n =data.K+3*data.P-1:2*data.K+3*data.P-2;

disp('Objective function'); O = imo_2011_NLS_all3FE(res.b,data,prm,0);
disp(O);
disp('R-sq'); disp(1-O./SStot);


%%% .................%%%
%%% FIGURE 3         %%%
%%% .................%%%

% Show the estimated and actual price-quantity for a few products

figure(4);

hscode = 73269000;
p = find(data.numfirms(:,1)==hscode, 1);
if isempty(p), error('HS code %d not found in data.numfirms', hscode); end
disp('HS code'); disp(data.numfirms(p,1));
destidx = find(data.usedata(:,p));
if isempty(destidx), error('HS code %d has no usable destination', hscode); end

% One panel per destination plus one spare panel for the legend, two per
% row. Sizing the grid only for the destinations overflows it (subplot
% index error) whenever the number of destinations is even.
nrows = ceil((numel(destidx)+1)/2);
a=1;
for n=destidx'
  subplot(nrows,2,a);
  hold on;

  % Scatter data
  sc1 = scatter(log(data.price{n,p}),data.lnX{n,p},'x');
  % Model: lnX = a(n,p) + phisig1_p(p) * log(price + t(n,p))
  lnXhat = res.a(n,p) + res.phisig1_p(p)*log(data.price{n,p} + exp(res.lntmat(n,p)));

  % Scatter model
  sc2 = scatter(log(data.price{n,p}),lnXhat,'.','LineWidth',2);
  xlabel('log f.o.b. price'); ylabel('log X'); title(data.countries(n));
  a=a+1;
end

% Legend in the spare panel after the last destination
sh = subplot(nrows,2,a);
pos = get(sh,'position');
lh=legend(sh,[sc1 sc2],'Data','Model');
set(lh,'position',pos);
axis(sh,'off');
set(lh, 'Box', 'off')
set(lh, 'Color', 'none')

if (savefigures) saveas(gcf,strcat(workdir,'/graphs/res_model_fit.eps')); end
     